#include <string>
#include <unordered_map>
#include <vector>
#include "Hooker.h"

// Build-time switches. A switch whose #define line is commented out is disabled.
//#define PRINT_ERROR			// MessageBox installing/activating function hooks (user32.dll dependency)
#define SETTAFF_CORRECT_RET		// Return expected previous thread affinity mask
//#define AFFTRANSL_GREATER		// Only translate affinities > cpu.affinityMask

// Number of logical processors the hooks report when the machine has fewer than this.
#define FAKE_CORES		((DWORD)8)
static_assert(FAKE_CORES > 0,"Invalid FAKE_CORES");

// Returns 2 raised to the power `exp`, computed in DWORD arithmetic.
// When `exp` is at least the bit width of DWORD the result is 0 (the value
// repeated unsigned doubling wraps to), so `pow2(n) - 1` is then an all-ones mask.
DWORD pow2(DWORD exp) {
	const DWORD bitWidth = static_cast<DWORD>(sizeof(DWORD) * 8);
	// A shift by the full width (or more) is undefined, so that case is answered directly.
	return (exp < bitWidth) ? (static_cast<DWORD>(1) << exp) : static_cast<DWORD>(0);
}

// Processor mask with the lowest FAKE_CORES bits set (0xFF for 8 fake cores).
DWORD const kFakeMask = pow2(FAKE_CORES) - 1;

struct Cpu {
	DWORD cores;
	DWORD affinityMask;
	DWORD glpiLength;

	enum LogicalInfoRel {
		kCore0 = 0,
		kCache0,
		kPackage,
		kSharedCache,
		kNuma,
		_LIR_COUNT_
	};

	SYSTEM_LOGICAL_PROCESSOR_INFORMATION logicalCpuInfo[_LIR_COUNT_];

	Cpu() : cores(0),affinityMask(0),glpiLength(0) {}

	bool setCoresInfo() {
		if(cores > 0) {
			return true;
		}
		SYSTEM_INFO info;
		GetSystemInfo(&info);
		cores = info.dwNumberOfProcessors;
		affinityMask = (cores > 0) ? (pow2(cores) - 1) : (0);

		for(unsigned char i = 0; i < _LIR_COUNT_; ++i) {
			logicalCpuInfo[i].ProcessorMask = kFakeMask;
		}
		DWORD coreLogicalInfoSize = sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION) * 3; // core/cache/cache
		// ((core0/cache0/cache0) * FAKE_CORES)/package/sharedCache/numa 
		// ((3 * FAKE_CORES) + 3) * sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION)
		glpiLength = (FAKE_CORES * coreLogicalInfoSize) + (sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION) * 3);
		PSYSTEM_LOGICAL_PROCESSOR_INFORMATION buffer = nullptr;
		DWORD returnLength = 0;
		while(GetLogicalProcessorInformation(buffer,&returnLength) == FALSE) {
			if(GetLastError() == ERROR_INSUFFICIENT_BUFFER) {
				if(buffer) {
					delete buffer;
				}
				buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION)new unsigned char[returnLength];
				if(!buffer) {
					returnLength = 0;
					break;
				}
			}
			else {
				returnLength = 0;
				break;
			}
		}

		if(returnLength > 0) {
			size_t byteOffset = 0;
			PSYSTEM_LOGICAL_PROCESSOR_INFORMATION ptr = buffer;
			while(byteOffset + sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION) <= returnLength) {
				switch(ptr->Relationship)
				{
				case RelationNumaNode:
					// Non-NUMA systems report a single record of this type. // First comments from MSDN example
					if(ptr->ProcessorMask & 1) {
						logicalCpuInfo[kNuma] = *ptr;
						logicalCpuInfo[kNuma].ProcessorMask = kFakeMask;
					}
					break;
				case RelationProcessorCore:
					// A hyperthreaded core supplies more than one logical processor.
					if(ptr->ProcessorMask == 1 && ptr->ProcessorCore.Flags == 0) {
						// Copy core0
						logicalCpuInfo[kCore0] = *ptr;
					}
					break;
				case RelationCache:
					// Cache data is in ptr->Cache, one CACHE_DESCRIPTOR structure for each cache. 
					if(ptr->ProcessorMask == 1 && ptr->Cache.Level == 1) {
						// Copy L1 cache from core0
						logicalCpuInfo[kCache0] = *ptr;
					}
					else if((ptr->ProcessorMask & 1) && ptr->ProcessorMask > 1 && ptr->Cache.Level > 1) {
						// Copy any L2/L3 cache shared with core0
						logicalCpuInfo[kSharedCache] = *ptr;
						logicalCpuInfo[kSharedCache].ProcessorMask = kFakeMask;
					}
					break;
				case RelationProcessorPackage:
					// Logical processors share a physical package.
					if(ptr->ProcessorMask & 1) {
						// Copy the package where core0 resides
						logicalCpuInfo[kPackage] = *ptr;
						logicalCpuInfo[kPackage].ProcessorMask = kFakeMask;
					}
					break;
				default:
					break;
				}

				byteOffset += sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION);
				ptr++;
			}
			delete[] buffer;
		} // if(returnLength > 0)
		return true;
	} // setCoresInfo

} cpu;

///////////////////////////////////////////////////////////////////////////////
//////////////////////////////// GetSystemInfo ////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Replacement for GetSystemInfo: forwards to the real API and, when the machine
// has fewer than FAKE_CORES processors and an output pointer was supplied,
// overwrites the processor count and active mask with the fake values.
void WINAPI _GetSystemInfo(__out LPSYSTEM_INFO lpSystemInfo) {
	// The real call runs in every case, with the caller's pointer passed through as-is.
	GetSystemInfo(lpSystemInfo);

	const bool patchResult = (lpSystemInfo != nullptr) && (cpu.cores < FAKE_CORES);
	if(patchResult) {
		lpSystemInfo->dwNumberOfProcessors = FAKE_CORES;
		lpSystemInfo->dwActiveProcessorMask = kFakeMask;
	}
}

///////////////////////////////////////////////////////////////////////////////
//////////////////////// GetLogicalProcessorInformation ///////////////////////
///////////////////////////////////////////////////////////////////////////////
// Replacement for GetLogicalProcessorInformation that reports FAKE_CORES
// single-threaded cores built from the templates captured in `cpu`.
//
// Buffer       - receives (3 * FAKE_CORES + 3) records: core/cache/cache per
//                fake core, then package, shared cache and NUMA node.
// ReturnLength - in: size of Buffer in bytes; out: bytes written on success,
//                or the required size when the buffer is too small.
// Returns TRUE on success. Returns FALSE with last error
// ERROR_INSUFFICIENT_BUFFER when *ReturnLength is smaller than the fabricated
// list, which also covers a size probe made with a null Buffer.
// The call is forwarded untouched to the real API when ReturnLength is null,
// when the machine already has at least FAKE_CORES processors, when the
// templates were never prepared, or when Buffer is null although the stated
// size is large enough.
BOOL WINAPI _GetLogicalProcessorInformation(__out PSYSTEM_LOGICAL_PROCESSOR_INFORMATION Buffer,
											__inout PDWORD ReturnLength) {
	if(!ReturnLength || cpu.glpiLength == 0 || cpu.cores >= FAKE_CORES) {
		return GetLogicalProcessorInformation(Buffer,ReturnLength);
	}
	if(*ReturnLength < cpu.glpiLength) {
		// Report the fabricated size so the caller's retry allocates enough room.
		*ReturnLength = cpu.glpiLength;
		SetLastError(ERROR_INSUFFICIENT_BUFFER);
		return FALSE;
	}
	if(!Buffer) {
		return GetLogicalProcessorInformation(Buffer,ReturnLength);
	}

	// Work on local copies so the shared templates are never modified.
	SYSTEM_LOGICAL_PROCESSOR_INFORMATION coreRecord = cpu.logicalCpuInfo[Cpu::kCore0];
	SYSTEM_LOGICAL_PROCESSOR_INFORMATION cacheRecord = cpu.logicalCpuInfo[Cpu::kCache0];

	PSYSTEM_LOGICAL_PROCESSOR_INFORMATION out = Buffer;
	for(DWORD core = 0; core < FAKE_CORES; ++core) {
		const DWORD processorMask = pow2(core);
		coreRecord.ProcessorMask = processorMask;
		cacheRecord.ProcessorMask = processorMask;

		*out++ = coreRecord;
		// The cache template is emitted twice per core (core/cache/cache layout).
		*out++ = cacheRecord;
		*out++ = cacheRecord;
	}
	// The processor mask of the last three records is kFakeMask.
	*out++ = cpu.logicalCpuInfo[Cpu::kPackage];
	*out++ = cpu.logicalCpuInfo[Cpu::kSharedCache];
	*out = cpu.logicalCpuInfo[Cpu::kNuma];

	// Tell the caller how many bytes are valid.
	*ReturnLength = cpu.glpiLength;
	return TRUE;
}


///////////////////////////////////////////////////////////////////////////////
//////////////////////////// SetThreadAffinityMask ////////////////////////////
///////////////////////////////////////////////////////////////////////////////
#ifdef SETTAFF_CORRECT_RET

// Replacement for SetThreadAffinityMask that applies a translated mask but
// answers with the mask the caller asked for last time.
//
// The real API returns the previous mask on success and 0 on failure. Because
// the mask actually applied may differ from the one requested, the requested
// mask is remembered per thread handle and handed back as the "previous" mask
// on the next successful call for that handle. The first successful call for
// a handle returns whatever the real API returned.
DWORD_PTR WINAPI _SetThreadAffinityMask(__in HANDLE hThread,
										__in DWORD_PTR dwThreadAffinityMask) {
	// NOTE(review): this map is read and written without any locking - confirm
	// whether the hook can be entered from several threads at once.
	static std::unordered_map<HANDLE,DWORD_PTR> lastRequestedMask;

	#ifdef AFFTRANSL_GREATER
	// Only a request numerically greater than the real mask is replaced by the
	// real mask; anything else is applied unmodified.
	//TODO: a 0 result may be an error caused by the modified affinity mask!
	const DWORD_PTR maskToApply = (cpu.affinityMask < dwThreadAffinityMask)
		? cpu.affinityMask
		: dwThreadAffinityMask;
	#else
	// Translate all masks. This may help to redistribute unbalanced core usage.
	const DWORD_PTR maskToApply = cpu.affinityMask;
	#endif

	const DWORD_PTR realPrevMask = SetThreadAffinityMask(hThread,maskToApply);
	if(realPrevMask == 0) {
		// Failure: nothing is recorded and the error result is passed on.
		return 0;
	}

	// One lookup: either records the first request for this handle, or finds the stored one.
	auto entry = lastRequestedMask.emplace(hThread,dwThreadAffinityMask);
	if(entry.second) {
		return realPrevMask;
	}
	const DWORD_PTR expectedPrevMask = entry.first->second;
	entry.first->second = dwThreadAffinityMask;
	return expectedPrevMask;

	// Possible alternative (not implemented): map the even bits of
	// dwThreadAffinityMask onto the even bits of [0,cpu.cores-1], and likewise
	// for the odd bits.
}

#else

// Replacement for SetThreadAffinityMask used when SETTAFF_CORRECT_RET is not
// defined: applies a translated mask and returns the real API's result as-is
// (previous mask on success, 0 on failure).
DWORD_PTR WINAPI _SetThreadAffinityMask(__in HANDLE hThread,
										__in DWORD_PTR dwThreadAffinityMask) {
	#ifdef AFFTRANSL_GREATER
	{
		// Only a request numerically greater than the real mask is replaced by
		// the real mask; anything else is applied unmodified.
		if(cpu.affinityMask < dwThreadAffinityMask) {
			return SetThreadAffinityMask(hThread,cpu.affinityMask);
		}
		return SetThreadAffinityMask(hThread,dwThreadAffinityMask);
	}
	#else
	{
		// Translate all masks: every request is applied as the real mask.
		return SetThreadAffinityMask(hThread,cpu.affinityMask);
	}
	#endif
}

#endif


//////////////////////////////////////
//////////////////////////////////////

// Hook manager used by DllMain to install, activate and destroy the API hooks.
Hooker hooker;

// DLL entry point.
// On process attach: captures the processor information, then installs and
// activates the three kernel32.dll hooks. On process detach: destroys the
// hooks. Every other notification is ignored. Always returns TRUE.
BOOL WINAPI DllMain(HMODULE hModule,DWORD fdwReason,LPVOID) {
	if(fdwReason == DLL_PROCESS_DETACH) {
		hooker.destroy();
		return TRUE;
	}
	if(fdwReason != DLL_PROCESS_ATTACH) {
		return TRUE;
	}

	DisableThreadLibraryCalls(hModule);
	// Must run before the hooks go live: the replacement functions read `cpu`.
	cpu.setCoresInfo();

#ifdef PRINT_ERROR
	// Show a message box with the status code whenever a hook step reports failure.
#define MSG(text,status) (MessageBoxA(NULL,std::string(text + std::to_string(status)).c_str(),"Error",0))
#define CALL_AND_MSG(method,text) (((result = method).first)? (true):(MSG(text,result.second)))

	std::pair<bool,NTSTATUS> result;
	// Install all three hooks first...
	CALL_AND_MSG(hooker.installHook("kernel32.dll","GetSystemInfo",_GetSystemInfo),"Procedure hook installation failed");
	CALL_AND_MSG(hooker.installHook("kernel32.dll","GetLogicalProcessorInformation",_GetLogicalProcessorInformation),"Procedure hook installation failed");
	CALL_AND_MSG(hooker.installHook("kernel32.dll","SetThreadAffinityMask",_SetThreadAffinityMask),"Procedure hook installation failed");
	// ...then activate them in the same order.
	CALL_AND_MSG(hooker.activateHook("GetSystemInfo"),"Procedure hook activation failed");
	CALL_AND_MSG(hooker.activateHook("GetLogicalProcessorInformation"),"Procedure hook activation failed");
	CALL_AND_MSG(hooker.activateHook("SetThreadAffinityMask"),"Procedure hook activation failed");
#else
	// Install all three hooks first...
	hooker.installHook("kernel32.dll","GetSystemInfo",_GetSystemInfo);
	hooker.installHook("kernel32.dll","GetLogicalProcessorInformation",_GetLogicalProcessorInformation);
	hooker.installHook("kernel32.dll","SetThreadAffinityMask",_SetThreadAffinityMask);
	// ...then activate them in the same order. Results are not checked here.
	hooker.activateHook("GetSystemInfo");
	hooker.activateHook("GetLogicalProcessorInformation");
	hooker.activateHook("SetThreadAffinityMask");
#endif

	return TRUE;
}